# Authors: Emil & Ivetta

# This script generates:
# Tables 14-19 in appendix

# This script runs the subgroup (heterogeneity) conjoint analysis.
# Beware: the conjoint plotting package (cjplot) needs categorical variables to be factors; otherwise it does not treat them as categorical.

pacman::p_load(haven, dplyr, labelled, xlsx, cjoint, sandwich, lmtest, cregg, car, stringr,
               ggtext, glue, ggsci)
#install.packages("ggstance")

# Conjoint data enriched with all variables from the second-wave survey.
# haven's labelled columns are converted to factors, because the conjoint
# tools only treat factors as categorical variables.
df <- read_sav("drafts/conjoint_solidarity/perspectives_docs/revision_2/replication_code/input_data.sav") %>%
  as_factor()

# NOTE(review): only `df` is created here, yet the heterogeneity section below
# reads `df_mig`, which this script never defines -- confirm where it comes from.


# Publication-style ggplot2 theme built on ggthemes::theme_foundation().
# NOTE(review): not used by the plots in this script, which use theme_bw().
#
# Args:
# - base_size: base font size passed to theme_foundation().
# - base_family: base font family passed to theme_foundation().
#
# Returns: a ggplot2 theme object, to be added to a plot with `+`.
theme_Publication <- function(base_size = 18, base_family = "Helvetica") {
  # Namespace-qualified calls instead of library(grid)/library(ggthemes), so
  # calling the theme neither attaches packages nor needs ggplot2 attached.
  ggthemes::theme_foundation(base_size = base_size, base_family = base_family) +
    ggplot2::theme(
      plot.title = ggplot2::element_text(face = "bold",
                                         size = ggplot2::rel(0.9),
                                         hjust = 0.5),
      text = ggplot2::element_text(),
      panel.background = ggplot2::element_rect(colour = NA),
      plot.background = ggplot2::element_rect(colour = NA),
      panel.border = ggplot2::element_rect(colour = NA),
      axis.title = ggplot2::element_text(face = "bold", size = ggplot2::rel(0.8)),
      axis.title.y = ggplot2::element_text(angle = 90, vjust = 2),
      axis.title.x = ggplot2::element_text(vjust = -0.2),
      axis.text = ggplot2::element_text(),
      axis.line = ggplot2::element_line(colour = "black"),
      axis.ticks = ggplot2::element_line(),
      panel.grid.major = ggplot2::element_line(colour = "#f0f0f0"),
      panel.grid.minor = ggplot2::element_blank(),
      legend.key = ggplot2::element_rect(colour = NA),
      legend.position = "bottom",
      legend.direction = "horizontal",
      legend.key.size = grid::unit(0.2, "cm"),
      # legend.margin must be a margin(); a bare unit() is treated by ggplot2
      # as the old legend-spacing setting (with a warning), so set both
      # explicitly instead.
      legend.spacing = grid::unit(0, "cm"),
      legend.margin = ggplot2::margin(0, 0, 0, 0, "cm"),
      legend.title = ggplot2::element_text(face = "italic"),
      plot.margin = grid::unit(c(10, 5, 5, 5), "mm"),
      strip.background = ggplot2::element_rect(colour = "#f0f0f0",
                                               fill = "#f0f0f0"),
      strip.text = ggplot2::element_text(face = "bold")
    )
}


################################################
################################################
################################################
################################################
# Heterogeneity
################################################
################################################
################################################
################################################


# Dependent variables, as a named vector (display name = column name).
output_vars <- c("Choice" = "choice")

# Subgroup (heterogeneity) variables, as a named vector
# (display name = column name); the names are used in plot and table titles.
subgroup_vars <- c(

  # Identity and socdem
  "Ethnic minority identity" = "identity_second",
  "Having ukrainian relatives" = "ukrainians",

  # War-related feelings
  "Feeling of guilt" = "guilt_bin",
  # NOTE(review): column spelled "responsiblity_bin" -- confirm it matches the data.
  "Feeling of responsibility" = "responsiblity_bin",
  # Political engagement
  "Political & civic engagement" = "polit_civic_bin",
  "Political interests" = "politics_interest_bin",
  "Host countries" = "country_top"
  )

# Analysis data shared by the plotting and LaTeX loops below.
# `df_mig` is not created by this script (only `df` is read above), so fail
# with an explicit message instead of a bare "object 'df_mig' not found".
if (!exists("df_mig")) {
  stop("`df_mig` not found: create the migrant data frame before running ",
       "the heterogeneity analysis (only `df` is loaded by this script).",
       call. = FALSE)
}
data <- df_mig

# Right-hand side of the conjoint model formula (the conjoint attributes).
cj_formula <- " ~ Age + Gender + Children + Profession + Ethnicity + Motivation"


# Feature headers to render in bold on the plots' y axis, combined into one
# regex alternation for highlight().
bold_labels <- c("Age:", "Gender:", "Children:", "Profession:",
                 "Ethnic background:", "Motivation for emigration:") %>%
  paste0(collapse = "|")

# Axis-label formatter for ggtext::element_markdown(): removes "(" and ")"
# from every label and wraps the labels that match the regex `pat` in a
# styled <b> tag.
#
# Args:
# - x: vector of axis labels.
# - pat: regular expression selecting the labels to bold.
# - color, family: CSS colour and font-family used inside the <b> tag.
#
# Returns: a vector of the same length as `x`; labels that do not match are
# returned with their parentheses removed but otherwise unchanged.
highlight <- function(x, pat, color = "black", family = "") {
  cleaned <- gsub("[()]", "", x)
  is_match <- grepl(pat, cleaned)
  bolded <- sprintf("<b style='font-family:%s; color:%s'>%s</b>",
                    family, color, cleaned)
  ifelse(is_match, bolded, cleaned)
}


# The plots render y-axis text with ggtext::element_markdown(), where a line
# break has to be written as "<br />" rather than "\n"; relabel this
# Motivation level accordingly.
library(forcats)

data <- mutate(
  data,
  Motivation = fct_recode(
    Motivation,
    "Can't accept Russian politics +<br />Was arrested in rallies" =
      "Can't accept Russian politics +\nWas arrested in rallies"
  )
)



# Subgroup plots ----
# For every outcome (y_var) and subgroup variable (x_var): estimate conditional
# AMCEs and MMs with cregg::cj(), add between-subgroup differences when the
# subgroup variable has few observed levels, and print one faceted plot per
# estimate type.

# Display names for the conjoint features, so that the graphs show the correct
# labels. The trailing colons are what `bold_labels` matches on.
label_rename <- c(
  Age = "Age:",
  Gender = "Gender:",
  Children = "Children:",
  Profession = "Profession:",
  Ethnicity = "Ethnic background:",
  Motivation = "Motivation for emigration:"
)

# Differences are only estimated when the subgroup variable has at most this
# many distinct non-missing values.
max_diff_levels_plot <- 4

# Set to TRUE to write every plot to disk as a JPEG (9 x 5.5 in).
save_plots <- FALSE

# Estimate one quantity for a single outcome/subgroup pair, print its faceted
# plot (optionally saving it) and return the plot invisibly.
#
# Args:
# - data: data passed to cregg::cj().
# - y_var, x_var: outcome column and subgroup column names.
# - estimate: "amce" or "mm"; differences use estimate = "<estimate>_diff".
# - x_label: display name of the subgroup variable, used in the title.
# - title_tag: estimate tag appended to the title, e.g. "(AMCEs)".
# - file_tag: file-name suffix, e.g. "amces".
plot_subgroup_estimates <- function(data, y_var, x_var, estimate, x_label,
                                    title_tag, file_tag) {
  cj_fml <- as.formula(paste0(y_var, cj_formula))
  by_fml <- as.formula(paste0("~", x_var))

  est <- cj(data, cj_fml, estimate = estimate, by = by_fml)
  if (length(na.omit(unique(data[[x_var]]))) <= max_diff_levels_plot) {
    diffs <- cj(data, cj_fml, estimate = paste0(estimate, "_diff"),
                by = by_fml)
    est <- rbind(est, diffs)
  }

  # Relabel features by their character value: indexing label_rename with a
  # factor would silently use the factor's integer codes instead.
  feature_chr <- as.character(est$feature)
  est$feature <- ifelse(feature_chr %in% names(label_rename),
                        label_rename[feature_chr],
                        feature_chr)

  plot_title <- paste("Choice of potential migrant by respondents subgroups:",
                      x_label,
                      title_tag)

  p <- plot(est, legend_title = x_var, size = 2) +
    ggplot2::facet_wrap(~BY, ncol = 3L) +
    ggplot2::ggtitle(plot_title) +
    ggplot2::theme_bw() +
    ggplot2::theme(legend.position = "none",
                   plot.title = ggplot2::element_text(size = 10)) +
    ggplot2::geom_point(size = 2) +
    ggsci::scale_color_nejm() +
    # bold the feature headers; element_markdown() also renders "<br />"
    ggplot2::scale_y_discrete(
      labels = function(x) highlight(x, bold_labels, "black")
    ) +
    ggplot2::theme(axis.text.y = ggtext::element_markdown())
  print(p)

  if (save_plots) {
    plot_file_name <- paste0("drafts/conjoint_solidarity/plots/extra_plots/hetero_",
                             x_var, "_", y_var, "_", file_tag, ".jpeg")
    ggplot2::ggsave(plot_file_name, plot = p, width = 9, height = 5.5)
  }
  invisible(p)
}

for (y_var in output_vars) {
  for (x_var in subgroup_vars) {
    print(x_var) # for monitoring
    x_label <- names(subgroup_vars)[which(subgroup_vars == x_var)]
    plot_subgroup_estimates(data, y_var, x_var, "amce", x_label,
                            "(AMCEs)", "amces")
    plot_subgroup_estimates(data, y_var, x_var, "mm", x_label,
                            "(MMs)", "mms")
  }
}




################################################
################################################
################################################
################################################
# Produce Latex File with all the names
################################################
################################################
################################################
################################################


# Open the output file for the generated LaTeX code. Mode "w" creates the file
# or truncates any earlier content, so a single open is enough.
file_conn <- file("drafts/conjoint_solidarity/_subgroups_table_plots_latex_code.txt", open = "w")

# generate_latex_plot_code: build the LaTeX code that inserts one figure.
#
# Args:
# - plot_file_name: file name of the plot including its ".jpeg" extension
#                   (e.g. "hetero_guilt_bin_choice_amces.jpeg"). It is the
#                   \includegraphics target; without the extension it is
#                   also the figure label "fig:<name>".
# - title_text: caption text, inserted verbatim (escape LaTeX specials first).
#
# Returns:
# - One character string holding a figure environment with [H] placement
#   (provided by the LaTeX float package). Nothing is written here; the
#   caller decides where to cat() the result.
generate_latex_plot_code <- function(plot_file_name, title_text) {
  # Strip only a trailing ".jpeg": an unescaped "." would match any character
  # anywhere in the name.
  base_name <- sub("\\.jpeg$", "", plot_file_name)
  sprintf("
\\begin{figure}[H]
    \\centering
    \\includegraphics[width=\\linewidth]{%s}
    \\caption{\\label{fig:%s} %s}
\\end{figure}
", plot_file_name, base_name, title_text)
}

# generate_latex_table_code: render one estimates table as LaTeX via xtable.
#
# Args:
# - data: cregg estimates for one subgroup; the outcome, statistic, lower,
#         upper and BY columns are dropped before rendering.
# - caption_title: caption text, inserted verbatim (escape LaTeX specials first).
# - label: suffix of the table label; the caption is prefixed with
#          \label{tab:<label>}.
#
# Returns: the LaTeX code as a character string. print.xtable() also writes
# the table to the console.
generate_latex_table_code <- function(data, caption_title, label) {
  table_data <- data %>%
    # remove columns that are not needed in the LaTeX output
    select(-outcome, -statistic, -lower, -upper, -BY)
  caption <- paste0("\\label{tab:", label, "}", caption_title)

  # include.rownames is a print.xtable() option (xtable() ignores it), so it
  # is only passed to print().
  table_output <- xtable::xtable(table_data, caption = caption, digits = 3)
  latex_code <- print(table_output, include.rownames = FALSE,
                      include.colnames = TRUE)
  latex_code
}


# LaTeX appendix ----
# For every outcome (y_var) and subgroup variable (x_var), write to file_conn:
# a \subsection; the AMCE figure and one AMCE table per BY value; the MM
# figure and one MM table per BY value; then a \pagebreak. The connection is
# closed at the end.

# Differences are only tabulated when the subgroup variable has at most this
# many distinct non-missing values.
# NOTE(review): the plotting loop uses 4 levels, so for a 4-level subgroup the
# figures show differences that these tables omit -- confirm which is intended.
max_diff_levels_table <- 3

# Escape characters that are special in LaTeX running text (e.g. the "&" in
# "Political & civic engagement") so that headings and captions compile.
escape_latex_text <- function(x) {
  gsub("([&%$#_])", "\\\\\\1", x)
}

# Estimate one quantity ("amce" or "mm") by subgroup for the tables. Appends
# between-subgroup differences when x_var has few levels, drops the column
# that only repeats x_var, and turns the HTML line breaks used for the plot
# labels into spaces (xtable would otherwise print them as literal text).
estimate_subgroup_table <- function(data, y_var, x_var, estimate, max_levels) {
  cj_fml <- as.formula(paste0(y_var, cj_formula))
  by_fml <- as.formula(paste0("~", x_var))

  est <- cj(data, cj_fml, estimate = estimate, by = by_fml)
  if (length(na.omit(unique(data[[x_var]]))) <= max_levels) {
    diffs <- cj(data, cj_fml, estimate = paste0(estimate, "_diff"),
                by = by_fml)
    est <- rbind(est, diffs)
  }

  est %>%
    select(-all_of(x_var)) %>%
    mutate(level = gsub("<br />", " ", as.character(level), fixed = TRUE))
}

# Write one table per element of est_list (one element per BY value) to conn.
# Labels keep the "<x_var><label_tag><first letter of BY>" scheme, but
# make.unique() suffixes later clashes. For example, "Yes" and a difference
# such as "Yes - No" share "Y", so each \label is unique; first occurrences
# keep their previous label.
write_subgroup_tables <- function(est_list, x_var, x_label, quantity,
                                  label_tag, conn) {
  outcomes <- names(est_list)
  label_keys <- make.unique(substring(outcomes, 1, 1), sep = "")
  for (i in seq_along(outcomes)) {
    table_title <- paste("Choice of potential migrant by respondents subgroups:",
                         x_label,
                         paste0("(", quantity, " for"),
                         escape_latex_text(outcomes[i]),
                         ").")
    table_code <- generate_latex_table_code(
      est_list[[outcomes[i]]],
      table_title,
      label = paste0(x_var, label_tag, label_keys[i])
    )
    cat(table_code, sep = "\n\n", file = conn)
  }
  invisible(NULL)
}

for (y_var in output_vars) {
  for (x_var in subgroup_vars) {
    x_label <- escape_latex_text(
      names(subgroup_vars)[which(subgroup_vars == x_var)]
    )

    # Subsection heading (the label uses the raw column name)
    subsection_title <- paste0("\\subsection{\\label{subsec:", x_var, "}",
                               "Subgroup Analysis: ", x_label, "}")
    cat(subsection_title, sep = "\n", file = file_conn)

    # Conditional AMCEs and MMs, split by the "BY" column into one table each
    amces <- estimate_subgroup_table(data, y_var, x_var, "amce",
                                     max_diff_levels_table)
    mms <- estimate_subgroup_table(data, y_var, x_var, "mm",
                                   max_diff_levels_table)
    amces_list <- split(amces, amces$BY)
    mms_list <- split(mms, mms$BY)

    # AMCE figure followed by its tables
    plot_amce_latex <- generate_latex_plot_code(
      paste0("hetero_", x_var, "_", y_var, "_amces", ".jpeg"),
      paste("Choice of potential migrant by respondents subgroups:",
            x_label,
            "(Average Marginal Component Effect).")
    )
    print(plot_amce_latex)
    cat(plot_amce_latex, sep = "\n\n", file = file_conn)
    write_subgroup_tables(amces_list, x_var, x_label,
                          "Average Marginal Component Effect", "amce",
                          file_conn)

    # MM figure followed by its tables
    plot_mm_latex <- generate_latex_plot_code(
      paste0("hetero_", x_var, "_", y_var, "_mms", ".jpeg"),
      paste("Choice of potential migrant by respondents subgroups:",
            x_label,
            "(Marginal Means).")
    )
    cat(plot_mm_latex, sep = "\n\n", file = file_conn)
    write_subgroup_tables(mms_list, x_var, x_label, "Marginal Means", "mm",
                          file_conn)

    # Start a new page after each subgroup variable
    cat("\\pagebreak", sep = "\n", file = file_conn)
  }
}

close(file_conn) # Close the connection when done
